{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using Tensorflow DALI plugin: DALI tf.data.Dataset with multiple GPUs\n",
    "\n",
    "### Overview\n",
    "\n",
    "This notebook is a comprehensive example on how to use DALI `tf.data.Dataset` with multiple GPUs. It is recommended to look into [single GPU example](tensorflow-dataset.ipynb) first to get up to speed with DALI dataset and how it can be used to train a neural network. This example is an extension of the single GPU version.\n",
    "\n",
    "Initially we define some parameters of the training and to create a DALI pipeline to read [MNIST](http://yann.lecun.com/exdb/mnist/) converted to LMDB format. You can find it in [DALI_extra](https://github.com/NVIDIA/DALI_extra) dataset. This pipeline is able to partition the dataset into multiple shards.\n",
    "\n",
    "`DALI_EXTRA_PATH` environment variable should point to the place where data from [DALI extra repository](https://github.com/NVIDIA/DALI_extra) is downloaded. Please make sure that the proper release tag is checked out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nvidia.dali as dali\n",
    "from nvidia.dali import pipeline_def, Pipeline\n",
    "import nvidia.dali.fn as fn\n",
    "import nvidia.dali.types as types\n",
    "\n",
    "import os\n",
    "\n",
    "import nvidia.dali.plugin.tf as dali_tf\n",
    "import tensorflow as tf\n",
    "\n",
    "import logging\n",
    "\n",
    "tf.get_logger().setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to MNIST dataset\n",
    "data_path = os.path.join(os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/training/\")\n",
    "\n",
    "BATCH_SIZE = 64\n",
    "DROPOUT = 0.2\n",
    "IMAGE_SIZE = 28\n",
    "NUM_CLASSES = 10\n",
    "HIDDEN_SIZE = 128\n",
    "EPOCHS = 5\n",
    "ITERATIONS = 100\n",
    "NUM_DEVICES = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@pipeline_def(batch_size=BATCH_SIZE)\n",
    "def mnist_pipeline(shard_id):\n",
    "    jpegs, labels = fn.readers.caffe2(\n",
    "        path=data_path,\n",
    "        random_shuffle=True,\n",
    "        shard_id=shard_id,\n",
    "        num_shards=NUM_DEVICES,\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n",
    "    images = fn.crop_mirror_normalize(\n",
    "        images, dtype=types.FLOAT, std=[255.0], output_layout=\"CHW\"\n",
    "    )\n",
    "\n",
    "    return images, labels.gpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we create some parameters needed for the DALI dataset. For more details on what they are you can look into [single GPU example](tensorflow-dataset.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "shapes = ((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE), (BATCH_SIZE))\n",
    "dtypes = (tf.float32, tf.int32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to define the model. To make the training distributed to multiple GPUs, we use `tf.distribute.MirroredStrategy`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "strategy = tf.distribute.MirroredStrategy(devices=[\"/gpu:0\", \"/gpu:1\"])\n",
    "\n",
    "with strategy.scope():\n",
    "    model = tf.keras.models.Sequential(\n",
    "        [\n",
    "            tf.keras.layers.Input(\n",
    "                shape=(IMAGE_SIZE, IMAGE_SIZE), name=\"images\"\n",
    "            ),\n",
    "            tf.keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE)),\n",
    "            tf.keras.layers.Dense(HIDDEN_SIZE, activation=\"relu\"),\n",
    "            tf.keras.layers.Dropout(DROPOUT),\n",
    "            tf.keras.layers.Dense(NUM_CLASSES, activation=\"softmax\"),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    model.compile(\n",
    "        optimizer=\"adam\",\n",
    "        loss=\"sparse_categorical_crossentropy\",\n",
    "        metrics=[\"accuracy\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DALI dataset needs to be distributed as well. To do it, we use `distribute_datasets_from_function`. First we need to define a function that returns dataset bound to a device given by id. Also, some specific options are needed to make everything work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataset_fn(input_context):\n",
    "    with tf.device(\"/gpu:{}\".format(input_context.input_pipeline_id)):\n",
    "        device_id = input_context.input_pipeline_id\n",
    "        return dali_tf.DALIDataset(\n",
    "            pipeline=mnist_pipeline(device_id=device_id, shard_id=device_id),\n",
    "            batch_size=BATCH_SIZE,\n",
    "            output_shapes=shapes,\n",
    "            output_dtypes=dtypes,\n",
    "            device_id=device_id,\n",
    "        )\n",
    "\n",
    "\n",
    "input_options = tf.distribute.InputOptions(\n",
    "    experimental_place_dataset_on_device=True,\n",
    "    experimental_fetch_to_device=False,\n",
    "    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA,\n",
    ")\n",
    "\n",
    "train_dataset = strategy.distribute_datasets_from_function(\n",
    "    dataset_fn, input_options\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With everything in place, we are ready to run the training and evaluate the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "100/100 [==============================] - 4s 8ms/step - loss: 1.2438 - accuracy: 0.6290\n",
      "Epoch 2/5\n",
      "100/100 [==============================] - 1s 8ms/step - loss: 0.3991 - accuracy: 0.8876\n",
      "Epoch 3/5\n",
      "100/100 [==============================] - 1s 8ms/step - loss: 0.3202 - accuracy: 0.9045\n",
      "Epoch 4/5\n",
      "100/100 [==============================] - 1s 9ms/step - loss: 0.2837 - accuracy: 0.9183\n",
      "Epoch 5/5\n",
      "100/100 [==============================] - 1s 8ms/step - loss: 0.2441 - accuracy: 0.9303\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f5d09685880>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=ITERATIONS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100/100 [==============================] - 2s 5ms/step - loss: 0.1963 - accuracy: 0.9438\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.19630344212055206, 0.9437500238418579]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(train_dataset, steps=ITERATIONS)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
